{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63d2b61e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.spatial import distance\n",
    "import plotly.express as px\n",
    "import matplotlib.pyplot as plt\n",
    "from plotly.subplots import make_subplots\n",
    "import plotly.graph_objects as go\n",
    "import torch\n",
    "import pdb\n",
    "import os\n",
    "from os.path import join\n",
    "import sys\n",
    "sys.path.append(\"./utils\")\n",
    "from LUT import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5358d7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the learned LUTs from a training checkpoint and add the identity tensor to them.\n",
    "# NOTE(review): torch.load unpickles the file, so only point this at checkpoints you trust.\n",
    "root = \"/mnt/tvmn1/input\"\n",
    "name = \"Test+10-1-1_models/\"\n",
    "epoch = 335\n",
    "# Edge length used for the (-1, 3, 33, 33, 33) reshape and passed to identity3d_tensor.\n",
    "lut_dim = 33\n",
    "\n",
    "# Checkpoint files are named model<epoch>.pth with the epoch zero-padded to four characters.\n",
    "checkpoint_path = os.path.join(root, name, f\"model{epoch:0>4}.pth\")\n",
    "model = torch.load(checkpoint_path, map_location=torch.device('cpu'))\n",
    "\n",
    "luts = cube_to_lut(model['LUT_model.LUTs'].reshape(-1, 3, lut_dim, lut_dim, lut_dim))\n",
    "idt = identity3d_tensor(lut_dim)\n",
    "# In-place add; unsqueeze(0) gives idt a leading axis of size 1 so it can broadcast against luts.\n",
    "luts += idt.unsqueeze(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79c6aed6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Integer index grid over a dim x dim x dim cube.\n",
    "dim = 33\n",
    "x = np.arange(dim)\n",
    "y = np.arange(dim)\n",
    "z = np.arange(dim)\n",
    "# meshgrid defaults to indexing='xy': the first output axis follows y, the second follows x.\n",
    "xx, yy, zz = np.meshgrid(x, y, z)\n",
    "# One row per grid point, columns in (xx, yy, zz) order.\n",
    "coords = np.vstack((xx.ravel(), yy.ravel(), zz.ravel())).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c11e1de",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_identity_vs_lut(lut_index):\n",
    "    \"\"\"Build a 1x2 figure of 3D scatters: grid coordinates next to one LUT.\n",
    "\n",
    "    Left scene: points placed at the notebook-level ``coords``, with columns\n",
    "    2, 0, 1 used as x, y, z. Right scene: points placed at the values of\n",
    "    ``luts[lut_index]`` reshaped to ``(N, 3)``. Both traces use the same\n",
    "    per-point color array, ``idt`` reshaped to ``(N, 3)``.\n",
    "\n",
    "    Reads the notebook globals ``coords``, ``idt`` and ``luts``.\n",
    "\n",
    "    Args:\n",
    "        lut_index: index into ``luts`` selecting the LUT to plot.\n",
    "\n",
    "    Returns:\n",
    "        ``(fig, lut)``: the plotly figure and the ``(N, 3)`` view of the LUT.\n",
    "    \"\"\"\n",
    "    point_colors = idt.reshape(3, -1).transpose(0, 1)\n",
    "    lut = luts[lut_index].reshape(3, -1).transpose(0, 1)\n",
    "\n",
    "    fig = make_subplots(rows=1, cols=2, specs=[[{\"type\": \"scene\"}, {\"type\": \"scene\"}]])\n",
    "    fig.add_trace(\n",
    "        go.Scatter3d(x=coords[:, 2],\n",
    "                     y=coords[:, 0],\n",
    "                     z=coords[:, 1],\n",
    "                     mode='markers',\n",
    "                     marker=dict(color=point_colors, opacity=0.5)),\n",
    "        row=1, col=1\n",
    "    )\n",
    "    fig.add_trace(\n",
    "        go.Scatter3d(x=lut[:, 0],\n",
    "                     y=lut[:, 1],\n",
    "                     z=lut[:, 2],\n",
    "                     mode='markers',\n",
    "                     marker=dict(color=point_colors, opacity=0.5)),\n",
    "        row=1, col=2\n",
    "    )\n",
    "    return fig, lut\n",
    "\n",
    "\n",
    "fig, lut = plot_identity_vs_lut(3)\n",
    "fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10a2511b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same comparison for the next LUT; plot_identity_vs_lut is defined in the previous cell.\n",
    "fig, lut = plot_identity_vs_lut(4)\n",
    "fig"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
